
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .utils.utils import get_activation_function, make_conv_block, make_fc_block


class Bottleneck(nn.Module):
    """Dense layer made of a 1x1 conv block followed by a 3x3 conv block.

    The first block maps ``nChannels`` to ``4 * growthRate`` channels and the
    second maps those to ``growthRate`` channels. Both blocks are built by
    ``make_conv_block`` using the ``'full'`` entry of ``oper_order``.

    The forward pass concatenates the layer input with the output of the
    second block along dimension 1.
    """

    def __init__(self, nChannels, growthRate, oper_order, activation_generator):
        super(Bottleneck, self).__init__()
        full_order = oper_order['full']
        hidden_channels = 4 * growthRate

        # 1x1 block: nChannels -> 4 * growthRate.
        self.conv1 = make_conv_block(
            nChannels, hidden_channels,
            activation_generator=activation_generator,
            kernel_size=1, padding=0,
            oper_order=full_order
        )
        # 3x3 block (padding 1): 4 * growthRate -> growthRate.
        self.conv2 = make_conv_block(
            hidden_channels, growthRate,
            activation_generator=activation_generator,
            kernel_size=3, padding=1,
            oper_order=full_order
        )

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with ``conv2(conv1(x))``."""
        new_features = self.conv2(self.conv1(x))
        return torch.cat((x, new_features), 1)


class SingleLayer(nn.Module):
    """Dense layer consisting of a single 3x3 conv block.

    The block is built by ``make_conv_block`` with ``nChannels`` inputs,
    ``growthRate`` outputs, padding 1, and the ``'full'`` entry of
    ``oper_order``. The forward pass concatenates the layer input with the
    block output along dimension 1.
    """

    def __init__(self, nChannels, growthRate, oper_order, activation_generator):
        super(SingleLayer, self).__init__()
        self.conv1 = make_conv_block(
            nChannels, growthRate,
            activation_generator=activation_generator,
            kernel_size=3, padding=1,
            oper_order=oper_order['full']
        )

    def forward(self, x):
        """Return ``x`` concatenated (dim 1) with ``conv1(x)``."""
        new_features = self.conv1(x)
        return torch.cat((x, new_features), 1)


class Transition(nn.Module):
    """Transition between dense stages: 1x1 conv block, then 2x2 average pool.

    The conv block is built by ``make_conv_block`` with ``nChannels`` inputs
    and ``nOutChannels`` outputs, using the ``'full'`` entry of
    ``oper_order``.
    """

    def __init__(self, nChannels, nOutChannels, oper_order, activation_generator):
        super(Transition, self).__init__()
        self.conv1 = make_conv_block(
            nChannels, nOutChannels,
            activation_generator=activation_generator,
            kernel_size=1, padding=0,
            oper_order=oper_order['full']
        )

    def forward(self, x):
        """Apply the conv block and average-pool the result with kernel 2."""
        return F.avg_pool2d(self.conv1(x), 2)


class DenseNet(nn.Module):
    """DenseNet classifier: stem conv, three dense stages, two transitions.

    Layout: a 3-channel 3x3 stem convolution, then ``dense1 -> trans1 ->
    dense2 -> trans2 -> dense3``, followed by batch norm, ReLU, an 8x8
    average pool, a linear classifier, and ``log_softmax``.

    Args:
        growthRate: Channel increment assumed for every dense layer; the
            stem produces ``2 * growthRate`` channels.
        depth: Nominal depth. Each dense stage gets ``(depth - 4) // 3``
            layers, halved (integer division) when ``bottleneck`` is truthy.
        reduction: Factor applied to the channel count at each transition;
            the result is floored to an int.
        nClasses: Output size of the final linear layer.
        bottleneck: If truthy, stages are built from ``Bottleneck`` layers,
            otherwise from ``SingleLayer`` layers.
        activation_type: Forwarded to ``get_activation_function``; the result
            is handed to every dense and transition layer.
        oper_order: Iterable copied into a list and stored as
            ``self.oper_order['full']``.
        dataset: Stored on the instance; not otherwise used by this class.

    NOTE(review): the final pool uses a fixed kernel of 8, so the feature map
    entering it is presumably 8x8 (e.g. 32x32 inputs, given the two 2x
    poolings in the transitions) -- confirm against the data pipeline.
    """

    def __init__(self, growthRate, depth, reduction, nClasses, bottleneck,
                 activation_type, oper_order, dataset):
        super(DenseNet, self).__init__()

        self.activation_generator = get_activation_function(activation_type)
        self.dataset = dataset
        # Keep a private copy so later changes to the caller's sequence do
        # not affect this model.
        self.oper_order = {'full': list(oper_order)}

        nDenseBlocks = (depth-4) // 3
        if bottleneck:
            nDenseBlocks //= 2

        # Stem convolution.
        nChannels = 2*growthRate
        self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,
                               bias=False)

        # Stage 1 and its transition.
        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans1 = Transition(nChannels, nOutChannels, self.oper_order, self.activation_generator)

        # Stage 2 and its transition.
        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate
        nOutChannels = int(math.floor(nChannels*reduction))
        self.trans2 = Transition(nChannels, nOutChannels, self.oper_order, self.activation_generator)

        # Stage 3 (no transition afterwards).
        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck)
        nChannels += nDenseBlocks*growthRate

        # Classification head.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.fc = nn.Linear(nChannels, nClasses)

        # Parameter initialisation over every submodule registered so far.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std = sqrt(2 / (kh * kw * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                # Only the bias is reset; the weight keeps its default init.
                m.bias.data.zero_()

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck):
        """Build one dense stage as an ``nn.Sequential`` of dense layers.

        Args:
            nChannels: Input channel count of the first layer.
            growthRate: Passed to each layer; also the amount by which the
                input channel count is increased for the next layer.
            nDenseBlocks: Number of layers (converted with ``int``).
            bottleneck: Selects ``Bottleneck`` (truthy) or ``SingleLayer``.

        Returns:
            ``nn.Sequential`` containing the layers in order.
        """
        layer_cls = Bottleneck if bottleneck else SingleLayer
        layers = []
        for _ in range(int(nDenseBlocks)):
            layers.append(layer_cls(nChannels, growthRate,
                                    self.oper_order, self.activation_generator))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network and return per-class log-probabilities.

        Returns:
            Tensor of shape ``(batch, nClasses)``, including when the batch
            size is 1.
        """
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)
        out = F.avg_pool2d(F.relu(self.bn1(out)), 8)
        # Flatten explicitly instead of torch.squeeze(): a bare squeeze also
        # removes the batch axis when the batch size is 1.
        out = out.view(out.size(0), -1)
        # Explicit class dimension; the implicit-dim form is deprecated.
        out = F.log_softmax(self.fc(out), dim=1)
        return out